#include "framework/unittest.hh"

#include "../core/plato_node.hh"
#include "../test/struct.hh"

using namespace plato;

/// Test-only logic registry: associates a node's static id with the logic
/// object returned for it by get_logic_instance().
class PlatoLogicManagerImpl : public PlatoLogicManager {
public:
  PlatoLogicManagerImpl() = default;
  virtual ~PlatoLogicManagerImpl() = default;

  /// Looks up the logic registered for `static_id`.
  /// @return the registered logic, or nullptr when none was registered.
  virtual auto get_logic_instance(PlatoNodeID static_id)
      -> PlatoNodeLogicPtr override {
    const auto found = logic_map_.find(static_id);
    return found != logic_map_.end() ? found->second : nullptr;
  }

  /// Registers `logic_ptr` for `static_id`. If the id is already registered
  /// the existing entry is kept (emplace never overwrites).
  auto add_logic(PlatoNodeID static_id, PlatoNodeLogicPtr logic_ptr) -> void {
    logic_map_.emplace(static_id, logic_ptr);
  }

private:
  std::unordered_map<PlatoNodeID, PlatoNodeLogicPtr> logic_map_;
};

class PlatoNodeLogic_node1 : public PlatoNodeLogic {
public:
  PlatoNodeLogic_node1() {}
  virtual ~PlatoNodeLogic_node1() {}
  virtual auto do_logic(PlatoNode *node, PlatoFlowStatus &status)
      -> ExecResult {
    status = PlatoFlowStatus::OK;
    auto *pin_ptr = node->get_output_pin(1);
    auto int_ptr = pin_ptr->dyn_cast<Int>();
    *int_ptr = 2;
    return {0};
  }
  virtual auto do_event(PlatoNode *node, VariablePtr var_ptr) -> void {}
};

class PlatoNodeLogic_node2 : public PlatoNodeLogic {
public:
  bool done{false};

public:
  PlatoNodeLogic_node2() {}
  virtual ~PlatoNodeLogic_node2() {}
  virtual auto do_logic(PlatoNode *node, PlatoFlowStatus &status)
      -> ExecResult {
    done = true;
    status = PlatoFlowStatus::OK;
    return ExecResult();
  }
  virtual auto do_event(PlatoNode *node, VariablePtr var_ptr) -> void {}
};

// Opens the TestPlatoNode fixture that groups the flow test cases below;
// closed by FIXTURE_END(TestPlatoNode) at the end of the file.
FIXTURE_BEGIN(TestPlatoNode)

// Single-domain flow: node 1 is the start node and node 2 is linked to both
// of node 1's outputs. Passes when node 2's logic runs after one update().
CASE(TestFlow1) {
  auto domain_ptr = new_domain(1);

  // Node 1 outputs: EXEC pin, then VAR pin backed by an Int.
  auto node_ptr1 = new_node(domain_ptr, 1, 1, PlatoNodeSyncType::NONE);
  auto node1_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node1_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE, 1);
  auto node1_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node1_pin2);
  node_ptr1->add_output(node1_pin_ptr1);
  node_ptr1->add_output(node1_pin_ptr2);

  // Node 2 inputs: its OWN EXEC pin and VAR pin (backed by its own Int).
  // Previously node 1's output pins were registered here by mistake, leaving
  // these pins and var_node2_pin2 unused. The links below connect the two.
  auto node_ptr2 = new_node(domain_ptr, 2, 2, PlatoNodeSyncType::NONE);
  auto node2_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node2_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE);
  auto node2_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node2_pin2);
  node_ptr2->add_input(node2_pin_ptr1);
  node_ptr2->add_input(node2_pin_ptr2);

  // Presumably {node_id, pin_index}: EXEC 1.0 -> 2.0, VAR 1.1 -> 2.1.
  auto flow = new_flow(1, domain_ptr);
  flow->add_node(node_ptr1);
  flow->add_node(node_ptr2);
  flow->add_link({1, 0}, {2, 0});
  flow->add_link({1, 1}, {2, 1});
  flow->set_start_node(node_ptr1);

  // Register logic by static id (1 -> node1 logic, 2 -> node2 logic).
  auto logic_manager_ptr = std::make_shared<PlatoLogicManagerImpl>();
  set_logic_manager(logic_manager_ptr);
  auto logic2 = std::make_shared<PlatoNodeLogic_node2>();
  logic_manager_ptr->add_logic(1, std::make_shared<PlatoNodeLogic_node1>());
  logic_manager_ptr->add_logic(2, logic2);

  flow->complete();

  flow->update();

  ASSERT_TRUE(logic2->done);
}

/// Builds flow 1 in `domain_ptr` for the client side. Node 1 has sync type
/// CLIENT_SIDE and is the start node; node 2 has sync type NONE and is linked
/// to node 1's EXEC and VAR outputs.
/// @return the assembled flow. complete() is not called here.
auto new_client_flow(DomainPtr domain_ptr) -> PlatoFlowPtr {
  // Node 1 outputs: EXEC pin, then VAR pin backed by an Int.
  auto node_ptr1 = new_node(domain_ptr, 1, 1, PlatoNodeSyncType::CLIENT_SIDE);
  auto node1_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node1_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE, 1);
  auto node1_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node1_pin2);
  node_ptr1->add_output(node1_pin_ptr1);
  node_ptr1->add_output(node1_pin_ptr2);

  // Node 2 inputs: its OWN EXEC pin and VAR pin (backed by its own Int),
  // instead of re-registering node 1's output pins as before.
  auto node_ptr2 = new_node(domain_ptr, 2, 2, PlatoNodeSyncType::NONE);
  auto node2_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node2_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE);
  auto node2_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node2_pin2);
  node_ptr2->add_input(node2_pin_ptr1);
  node_ptr2->add_input(node2_pin_ptr2);

  auto flow1 = new_flow(1, domain_ptr);
  flow1->add_node(node_ptr1);
  flow1->add_node(node_ptr2);
  flow1->add_link({1, 0}, {2, 0});
  flow1->add_link({1, 1}, {2, 1});
  flow1->set_start_node(node_ptr1);

  return flow1;
}

/// Builds flow 1 in `domain_ptr` for the server side. Node 1 has sync type
/// SERVER_SIDE and is the start node; node 2 has sync type NONE and is linked
/// to node 1's EXEC and VAR outputs.
/// @return the assembled flow. complete() is not called here.
auto new_server_flow(DomainPtr domain_ptr) -> PlatoFlowPtr {
  // Node 1 outputs: EXEC pin, then VAR pin backed by an Int.
  auto node_ptr1 = new_node(domain_ptr, 1, 1, PlatoNodeSyncType::SERVER_SIDE);
  auto node1_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node1_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE, 1);
  auto node1_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node1_pin2);
  node_ptr1->add_output(node1_pin_ptr1);
  node_ptr1->add_output(node1_pin_ptr2);

  // Node 2 inputs: its OWN EXEC pin and VAR pin (backed by its own Int),
  // instead of re-registering node 1's output pins as before.
  auto node_ptr2 = new_node(domain_ptr, 2, 2, PlatoNodeSyncType::NONE);
  auto node2_pin_ptr1 = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  auto var_node2_pin2 = domain_ptr->New<Int>(PlatoVariableSyncType::NONE);
  auto node2_pin_ptr2 = new_pin(domain_ptr, PlatoPinType::VAR, var_node2_pin2);
  node_ptr2->add_input(node2_pin_ptr1);
  node_ptr2->add_input(node2_pin_ptr2);

  auto flow1 = new_flow(1, domain_ptr);
  flow1->add_node(node_ptr1);
  flow1->add_node(node_ptr2);
  flow1->add_link({1, 0}, {2, 0});
  flow1->add_link({1, 1}, {2, 1});
  flow1->set_start_node(node_ptr1);

  return flow1;
}

// Client/server round trip. The client flow lives in domain 1 and the server
// flow in domain 2. Their domain streams and flow streams are piped into each
// other by hand with operator>>, interleaved with update() calls.
// NOTE(review): both flows resolve logic through the single global logic
// manager, so logic1/logic2 are shared. `logic2->done` therefore cannot tell
// whether node 2 ran on the client or on the server.
// NOTE(review): unlike TestFlow1, complete() is never called on either flow
// before update() — confirm this is intended.
CASE(TestFlow2) {
  // Register logic by static id (1 -> node1 logic, 2 -> node2 logic).
  auto logic_manager_ptr = std::make_shared<PlatoLogicManagerImpl>();
  set_logic_manager(logic_manager_ptr);
  auto logic1 = std::make_shared<PlatoNodeLogic_node1>();
  logic_manager_ptr->add_logic(1, logic1);
  auto logic2 = std::make_shared<PlatoNodeLogic_node2>();
  logic_manager_ptr->add_logic(2, logic2);

  auto domain1 = new_domain(1);
  auto domain2 = new_domain(2);

  auto client_flow = new_client_flow(domain1);
  auto server_flow = new_server_flow(domain2);
  client_flow->update();
  // Presumably forwards client-side data (domain stream, then flow output
  // stream) into the server side — confirm operator>> direction.
  client_flow->get_domain()->get_stream() >>
      server_flow->get_domain()->get_stream();
  client_flow->get_ostream() >> server_flow->get_istream();
  client_flow->update();
  server_flow->update();
  // Presumably forwards the server's replies back to the client.
  server_flow->get_domain()->get_stream() >>
      client_flow->get_domain()->get_stream();
  server_flow->get_ostream() >> client_flow->get_istream();
  client_flow->update();

  ASSERT_TRUE(logic2->done);
}

// Closes the TestPlatoNode fixture opened by FIXTURE_BEGIN(TestPlatoNode).
FIXTURE_END(TestPlatoNode)
